{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function, division"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Headless servers have no X display: start xvfb so that game videos can be recorded.",
    "\n",
    "# xvfb must be installed; it is started through the helper script ../xvfb.",
    "\n",
    "import os",
    "\n",
    "if not os.environ.get(\"DISPLAY\"):  # DISPLAY is unset or empty",
    "\n",
    "    !bash ../xvfb start",
    "\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are new to this course and want more instructions on how to set up the environment and all the libs (docker / windows / gpu / blas / etc.), you can read the [vital instructions here](https://github.com/yandexdataschool/Practical_RL/issues/1#issue-202648393). \n",
    "\n",
    "Please make sure that you have bleeding edge versions of Theano, Lasagne and Agentnet."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# General purpose libs import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt",
    "\n",
    "import numpy as np",
    "\n",
    "%matplotlib inline",
    "\n",
    "from timeit import default_timer as timer",
    "\n",
    "\n",
    "from IPython.core import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you have a GPU, uncomment the line below (it must run before theano is imported).",
    "\n",
    "# Keep the value free of spaces: THEANO_FLAGS is a comma-separated list of key=value pairs.",
    "\n",
    "# %env THEANO_FLAGS=device=gpu0,floatX=float32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Universal collection of a gentleman:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym",
    "\n",
    "\n",
    "from agentnet.agent import Agent",
    "\n",
    "from agentnet.experiments.openai_gym.wrappers import PreprocessImage",
    "\n",
    "from agentnet.memory import WindowAugmentation, LSTMCell, GRUCell",
    "\n",
    "from agentnet.target_network import TargetNetwork",
    "\n",
    "from agentnet.resolver import EpsilonGreedyResolver, ProbabilisticResolver",
    "\n",
    "from agentnet.experiments.openai_gym.pool import EnvPool",
    "\n",
    "from agentnet.learning import qlearning",
    "\n",
    "\n",
    "import theano",
    "\n",
    "import theano.tensor as T",
    "\n",
    "\n",
    "import lasagne",
    "\n",
    "from lasagne.layers import DenseLayer, Conv2DLayer, InputLayer, NonlinearityLayer",
    "\n",
    "from lasagne.layers import batch_norm, get_all_params, get_output, reshape, concat, dropout",
    "\n",
    "from lasagne.nonlinearities import rectify, leaky_rectify, elu, tanh, softmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper function definitions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Downsample image, and crop it, showing only the most useful part of image. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_env():",
    "\n",
    "    \"\"\"Build the KungFuMaster-v0 environment wrapped in `PreprocessImage`.",
    "\n",
    "\n",
    "    The wrapper is configured for 64x64 grayscale frames and receives a crop",
    "\n",
    "    callable that keeps ``img[60:-30, 7:]`` of every frame.",
    "\n",
    "    \"\"\"",
    "\n",
    "    def crop_frame(img):",
    "\n",
    "        # first axis: skip the first 60 and the last 30 entries; second axis: skip the first 7",
    "\n",
    "        return img[60:-30, 7:]",
    "\n",
    "\n",
    "    raw_env = gym.make(\"KungFuMaster-v0\")",
    "\n",
    "    return PreprocessImage(raw_env, height=64, width=64,",
    "\n",
    "                           grayscale=True, crop=crop_frame)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Function for tracking performance while training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_and_plot(rewards, epoch_counter, pool, target_score, th_times, loop_times):",
    "\n",
    "    \"\"\"Evaluate the agent, record the mean score and redraw the learning curve.",
    "\n",
    "\n",
    "    Args:",
    "\n",
    "        rewards: dict {epoch: mean evaluation reward}; updated in place.",
    "\n",
    "        epoch_counter: key under which the new mean score is stored.",
    "\n",
    "        pool: object whose ``evaluate(n_games, record_video, verbose)`` is called.",
    "\n",
    "        target_score: not used here; kept so that existing calls stay valid.",
    "\n",
    "        th_times: durations of the training steps, averaged for the plot title.",
    "\n",
    "        loop_times: durations of the whole iterations, averaged for the plot title.",
    "\n",
    "\n",
    "    Relies on the notebook-level constant N_EVAL_GAMES.",
    "\n",
    "    \"\"\"",
    "\n",
    "    def mean_or_nan(values):",
    "\n",
    "        # np.mean([]) emits a RuntimeWarning; an empty timing list only means \"no data yet\"",
    "\n",
    "        return np.mean(values) if len(values) else float(\"nan\")",
    "\n",
    "\n",
    "    mean_score = np.mean(pool.evaluate(",
    "\n",
    "        n_games=N_EVAL_GAMES, record_video=False, verbose=False))",
    "\n",
    "    rewards[epoch_counter] = mean_score",
    "\n",
    "\n",
    "    info_string = \"Time (DL/All) {:.1f}/{:.1f}  epoch={}, mean_score={:.2f}\".format(",
    "\n",
    "        mean_or_nan(th_times), mean_or_nan(loop_times), epoch_counter, mean_score)",
    "\n",
    "\n",
    "    # plot the scores against the epochs they were measured at, not against their list index",
    "\n",
    "    epochs = sorted(rewards)",
    "\n",
    "    plt.figure(figsize=(8, 5))",
    "\n",
    "    plt.plot(epochs, [rewards[epoch] for epoch in epochs])",
    "\n",
    "    plt.grid()",
    "\n",
    "    plt.xlabel(\"Epoch\")",
    "\n",
    "    plt.ylabel(\"Mean reward over evaluation games\")",
    "\n",
    "    plt.title(info_string)",
    "\n",
    "    plt.show()",
    "\n",
    "    # wait=True postpones the clearing until the next output arrives, which avoids flicker",
    "\n",
    "    display.clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment setup\n",
    "Here we basically just load the game and check that it works"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('KungFuMaster-v0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(env.env.get_action_meanings())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(env.reset())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_env()",
    "\n",
    "plt.imshow(np.squeeze(env.reset()), interpolation='none', cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global constants definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All hyperparameters (except the number of layers and neurons) are declared here as upper-case names, along with the other global variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_ACTIONS = env.action_space.n",
    "\n",
    "OBS_SHAPE = env.observation_space.shape",
    "\n",
    "OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH = OBS_SHAPE",
    "\n",
    "\n",
    "# These 4 constants were shown to lead to nearly state-of-the-art results on the Kung-Fu Master game",
    "\n",
    "N_SIMULTANEOUS_GAMES = 10  # also known as the number of agents in exp_replay_pool",
    "\n",
    "SEQ_LENGTH = 25",
    "\n",
    "\n",
    "EVAL_EVERY_N_ITER = 100",
    "\n",
    "N_EVAL_GAMES = 2",
    "\n",
    "\n",
    "N_FRAMES_IN_BUFFER = 4  # number of consecutive frames to feed into the CNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A2C with memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# current observation: a batch of frames of shape OBS_SHAPE",
    "\n",
    "observation_layer = InputLayer((None,) + OBS_SHAPE)",
    "\n",
    "\n",
    "# previous window state: N_FRAMES_IN_BUFFER frames per batch element",
    "\n",
    "window_shape = [None, N_FRAMES_IN_BUFFER, OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH]",
    "\n",
    "prev_wnd = InputLayer(window_shape)",
    "\n",
    "new_wnd = WindowAugmentation(observation_layer, prev_wnd)",
    "\n",
    "\n",
    "# fold the frame axis into the channel axis so the window can be fed to 2D convolutions",
    "\n",
    "stacked_shape = [-1,  N_FRAMES_IN_BUFFER * OBS_CHANNELS, OBS_HEIGHT, OBS_WIDTH]",
    "\n",
    "wnd_reshape = reshape(new_wnd, stacked_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TYPE YOUR CODE HERE",
    "\n",
    "# provide the main body of the network: the first three convolutional layers and a dense one on top",
    "\n",
    "# you may want to change the nonlinearity - feel free to do this",
    "\n",
    "# note that we have changed the filter size here because of the reduced image width and height compared to those in papers",
    "\n",
    "conv1 = Conv2DLayer(wnd_reshape, ...)",
    "\n",
    "...",
    "\n",
    "# the layer class imported above is `DenseLayer`; there is no `Dense` in scope",
    "\n",
    "dense = DenseLayer(...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# YOUR CODE HERE",
    "\n",
    "# define a 256-neuron LSTM cell:",
    "\n",
    "# - define two input layers, each of n_lstm_cells (maybe 256 is a good baseline) neurons",
    "\n",
    "# - feed these two layers into `LSTMCell`, together with the",
    "\n",
    "#   input layer (the last dense layer in case of A2C+LSTM) as an additional third parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neck_layer = concat([ < dense layer before lstm > , < output of LSTM layer > ]) # network neck "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# YOUR CODE HERE",
    "\n",
    "# define actors head as",
    "\n",
    "# - logits_layer – dense(neck) with nonlinearity=None",
    "\n",
    "# - policy layer – softmax over logits_layer",
    "\n",
    "........",
    "\n",
    "action_layer = ProbabilisticResolver(policy_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# critic head",
    "\n",
    "V_layer = DenseLayer(neck_layer, 1, nonlinearity=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# YOUR CODE HERE",
    "\n",
    "# `observation_layers` is input layer to NN, as usual",
    "\n",
    "# `policy_estimators` should include 1) logits_layer and 2) V_layer",
    "\n",
    "# `agent_states` is a dictionary of {new_value: old_value}. You should bother to update",
    "\n",
    "#    a) prev window (input buffer, prev_wnd)  b) previous LSTM cell state  c) output of LSTM cell",
    "\n",
    "# `action_layers` is action_layer, as usual : )",
    "\n",
    "agent = Agent(....)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# may need to adjust (increasing N_SIMULTANEOUS_GAMES is usually a good idea)",
    "\n",
    "pool = EnvPool(agent, make_env, n_games=N_SIMULTANEOUS_GAMES)",
    "\n",
    "replay = pool.experience_replay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _, action_seq, (logits_seq, V_seq) = agent.get_sessions(",
    "\n",
    "    replay,",
    "\n",
    "    session_length=SEQ_LENGTH,",
    "\n",
    "    experience_replay=True",
    "\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute pi(a|s) and log(pi(a|s)) manually [use logsoftmax]",
    "\n",
    "# we can't guarantee that theano optimizes logsoftmax automatically since it's still in dev",
    "\n",
    "# for more info see (https://github.com/Theano/Theano/issues/2944 of 2015 year)",
    "\n",
    "\n",
    "# logits_seq.shape is (batch_size, SEQ_LENGTH, N_ACTIONS)",
    "\n",
    "logits_flat = logits_seq.reshape([-1, N_ACTIONS])",
    "\n",
    "policy_seq = T.nnet.softmax(logits_flat).reshape(logits_seq.shape)",
    "\n",
    "logpolicy_seq = T.nnet.logsoftmax(logits_flat).reshape(logits_seq.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get policy gradient",
    "\n",
    "from agentnet.learning import a2c",
    "\n",
    "elwise_actor_loss, elwise_critic_loss = a2c.get_elementwise_objective(",
    "\n",
    "    policy=logpolicy_seq,",
    "\n",
    "    treat_policy_as_logpolicy=True,",
    "\n",
    "    state_values=V_seq[:, :, 0],",
    "\n",
    "    actions=replay.actions[0],",
    "\n",
    "    rewards=replay.rewards/10,",
    "\n",
    "    is_alive=replay.is_alive,",
    "\n",
    "    gamma_or_gammas=0.99,",
    "\n",
    "    n_steps=None,",
    "\n",
    "    return_separate=True",
    "\n",
    ")",
    "\n",
    "\n",
    "# add losses with magic numbers",
    "\n",
    "# (you can change them more or less harmlessly, this usually just makes learning faster/slower)",
    "\n",
    "# actor and critic multipliers were selected guided by prior knowledge",
    "\n",
    "# entropy / regularization multipliers were tuned with logscale gridsearch",
    "\n",
    "# NB: regularization affects exploration",
    "\n",
    "reg_logits = T.mean(logits_seq ** 2)",
    "\n",
    "reg_entropy = T.mean(T.sum(policy_seq * logpolicy_seq, axis=-1))",
    "\n",
    "loss = 0.1 * elwise_actor_loss.mean() + 0.25 * elwise_critic_loss.mean() + \\",
    "\n",
    "    1e-3 * reg_entropy + 1e-3 * reg_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every trainable parameter reachable from the critic and the actor heads",
    "\n",
    "weights = get_all_params([V_layer, policy_layer], trainable=True)",
    "\n",
    "\n",
    "# gradients of the total loss, passed through a total-norm constraint of 10 for stability",
    "\n",
    "grads = lasagne.updates.total_norm_constraint(T.grad(loss, weights), 10)",
    "\n",
    "\n",
    "# adam updates compiled into a no-argument function that returns the loss",
    "\n",
    "updates = lasagne.updates.adam(grads, weights)",
    "\n",
    "train_step = theano.function([], loss, updates=updates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch_counter = 1  # starting epoch",
    "\n",
    "rewards = {}  # full game rewards",
    "\n",
    "target_score = 10000",
    "\n",
    "loss, eval_rewards = 0, []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "untrained_reward = np.mean(pool.evaluate(",
    "\n",
    "    n_games=5, record_video=False, verbose=False))",
    "\n",
    "untrained_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If the stderr messages produced by pool.evaluate() bother you because",
    "\n",
    "# they pollute the output of the jupyter cell, you can do one of the following:",
    "\n",
    "# 1. use warnings.filterwarnings(\"ignore\")",
    "\n",
    "# 2. use the cell magic %%capture",
    "\n",
    "# 3. simply redirect stderr to /dev/null with the commands",
    "\n",
    "#    import os, sys",
    "\n",
    "#    stderr_old = sys.stderr",
    "\n",
    "#    sys.stderr = open(os.devnull, 'w')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "th_times, loop_times = [], []",
    "\n",
    "for i in range(2000):",
    "\n",
    "    loop_starts = timer()",
    "\n",
    "    pool.update(SEQ_LENGTH)",
    "\n",
    "\n",
    "    train_starts = timer()",
    "\n",
    "\n",
    "    # YOUR CODE HERE : train network (actor and critic)",
    "\n",
    "    raise NotImplementedError",
    "\n",
    "\n",
    "    th_times.append(timer() - train_starts)",
    "\n",
    "    epoch_counter += 1",
    "\n",
    "    loop_times.append(timer() - loop_starts)",
    "\n",
    "\n",
    "    # You may want to set EVAL_EVERY_N_ITER=1 for the time being",
    "\n",
    "    if epoch_counter % EVAL_EVERY_N_ITER == 0:",
    "\n",
    "        eval_and_plot(rewards, epoch_counter, pool,",
    "\n",
    "                      target_score, th_times, loop_times)",
    "\n",
    "        if rewards[epoch_counter] >= target_score:",
    "\n",
    "            print(\"VICTORY!\")",
    "\n",
    "            break",
    "\n",
    "        th_times, loop_times = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_and_plot(rewards, epoch_counter, pool, target_score, th_times, loop_times)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
